package com.jeasy.base.resolver;

import javax.servlet.ServletRequest;

import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.util.Enumeration;
import java.util.List;
import java.util.Map;
import java.util.Set;

import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.BeanUtils;
import org.springframework.core.MethodParameter;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.web.bind.WebDataBinder;
import org.springframework.web.bind.support.WebDataBinderFactory;
import org.springframework.web.context.request.NativeWebRequest;
import org.springframework.web.method.support.HandlerMethodArgumentResolver;
import org.springframework.web.method.support.ModelAndViewContainer;
import org.springframework.web.servlet.mvc.method.annotation.ExtendedServletRequestDataBinder;

import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.google.gson.JsonArray;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.gson.JsonParser;

/**
 * Resolves every controller method argument from the current request, either from plain
 * key/value request parameters or from request parameters whose single value is a JSON object.
 *
 * <p>{@code String}, {@code Integer}, {@code Long}, {@code Double}, {@code Float}, {@code List},
 * {@code Set} and array parameters are read directly from the collected values stored under the
 * parameter's name. Any other type is taken from the model (or instantiated) and populated through
 * Spring data binding: if it carries {@link FromJson} it is bound from JSON-valued request
 * parameters, otherwise from the plain key/value request parameters.
 *
 * @author taomk
 * @version 1.0
 * @since 15-8-4 4:45 PM
 */
public class ArgumentFromJsonResolver implements HandlerMethodArgumentResolver {

	/**
	 * Accepts every controller parameter; the resolution strategy is chosen per parameter in
	 * {@link #resolveArgument}.
	 */
	@Override
	public boolean supportsParameter(MethodParameter parameter) {
		return true;
	}

	@Override
	public Object resolveArgument(MethodParameter parameter, ModelAndViewContainer mavContainer, NativeWebRequest webRequest, WebDataBinderFactory binderFactory) throws Exception {
		return bindParametersToRequest(parameter, mavContainer, webRequest, binderFactory);
	}

	/**
	 * Collects the relevant request values into a scratch request and turns them into the argument.
	 *
	 * @return the converted simple value/collection, or the bound target object for other types
	 */
	private Object bindParametersToRequest(MethodParameter parameter, ModelAndViewContainer mavContainer, NativeWebRequest webRequest, WebDataBinderFactory binderFactory) throws Exception {
		String alias = getAlias(parameter);

		// Reuse a model attribute with the same name if present; otherwise create a fresh instance
		// of the parameter type.
		Object obj = mavContainer.containsAttribute(alias)
				? mavContainer.getModel().get(alias)
				: createAttribute(parameter.getParameterType());
		// bindParameters casts this binder to ExtendedServletRequestDataBinder for non-simple types.
		WebDataBinder binder = binderFactory.createBinder(webRequest, obj, alias);
		Object target = binder.getTarget();
		// Scratch request that receives the flattened key/value pairs to bind from.
		MockHttpServletRequest newRequest = new MockHttpServletRequest();

		boolean isFlag = isSimpleType(parameter.getParameterType());

		if (target != null) {
			bindParameters(webRequest, binder, newRequest, parameter, isFlag);
		}

		if (isFlag) {
			// Boxed scalar, String, List, Set or array: convert the collected values directly.
			return buildRequestToTarget(alias, newRequest, parameter, target);
		}
		// Any other object type: it was populated by the data binder.
		return target;
	}

	/**
	 * Returns {@code true} for the parameter types that are converted directly from request values
	 * instead of going through data binding.
	 */
	private static boolean isSimpleType(Class<?> type) {
		return type.equals(String.class)
				|| type.equals(Integer.class)
				|| type.equals(Long.class)
				|| type.equals(Double.class)
				|| type.equals(Float.class)
				|| type.equals(List.class)
				|| type.equals(Set.class)
				|| type.isArray();
	}

	/**
	 * Converts the values collected under {@code alias} into the parameter's declared type.
	 *
	 * @return the converted value, {@code null} when no value was collected, or {@code target}
	 *         for a type this method does not handle
	 */
	private Object buildRequestToTarget(String alias, MockHttpServletRequest newRequest, MethodParameter parameter, Object target) throws IllegalAccessException, InstantiationException {
		Class<?> type = parameter.getParameterType();

		if (type.equals(String.class)) {
			return newRequest.getParameter(alias);
		}

		if (type.equals(Integer.class) || type.equals(Long.class)
				|| type.equals(Double.class) || type.equals(Float.class)) {
			String value = newRequest.getParameter(alias);
			return value == null ? null : parseValue(type, value);
		}

		if (type.equals(List.class) || type.equals(Set.class)) {
			String[] values = collectValues(newRequest, alias);
			if (values == null || values.length == 0) {
				return null;
			}
			// A raw List/Set yields null here, so its elements stay Strings.
			Type elementType = getSingleTypeArgument(parameter.getGenericParameterType());
			List<Object> converted = Lists.newArrayList();
			for (String val : values) {
				converted.add(parseValue(elementType, val));
			}
			if (type.equals(List.class)) {
				return converted;
			}
			return Sets.newHashSet(converted);
		}

		if (type.isArray()) {
			String[] values = collectValues(newRequest, alias);
			if (values == null || values.length == 0) {
				return null;
			}
			Class<?> componentType = type.getComponentType();
			Object[] converted;
			if (componentType.equals(Integer.class)) {
				converted = new Integer[values.length];
			} else if (componentType.equals(Long.class)) {
				converted = new Long[values.length];
			} else if (componentType.equals(Double.class)) {
				converted = new Double[values.length];
			} else if (componentType.equals(Float.class)) {
				converted = new Float[values.length];
			} else {
				// Any other component type receives the raw String[] values.
				return values;
			}
			for (int i = 0; i < values.length; i++) {
				converted[i] = parseValue(componentType, values[i]);
			}
			return converted;
		}

		return target;
	}

	/**
	 * Returns every value stored under {@code alias}. If there is none, falls back to the indexed
	 * keys {@code alias[0]}, {@code alias[1]}, ... that {@link #bindJsonToRequest} produces for
	 * JSON arrays.
	 *
	 * @return the values, or the (null/empty) direct lookup result when nothing was found
	 */
	private static String[] collectValues(MockHttpServletRequest request, String alias) {
		String[] values = request.getParameterValues(alias);
		if (values != null && values.length > 0) {
			return values;
		}
		List<String> indexed = Lists.newArrayList();
		String value;
		for (int i = 0; (value = request.getParameter(alias + "[" + i + "]")) != null; i++) {
			indexed.add(value);
		}
		return indexed.isEmpty() ? values : indexed.toArray(new String[indexed.size()]);
	}

	/**
	 * Returns the single type argument of a parameterized type such as {@code List<Long>}, or
	 * {@code null} for raw or otherwise unparameterized types.
	 */
	private static Type getSingleTypeArgument(Type genericType) {
		if (genericType instanceof ParameterizedType) {
			Type[] args = ((ParameterizedType) genericType).getActualTypeArguments();
			if (args.length == 1) {
				return args[0];
			}
		}
		return null;
	}

	/**
	 * Parses {@code value} as {@code Integer}, {@code Long}, {@code Double} or {@code Float} when
	 * {@code type} is one of those classes; otherwise returns {@code value} unchanged.
	 *
	 * @throws NumberFormatException if a numeric type is requested and {@code value} is not a number
	 */
	private static Object parseValue(Type type, String value) {
		if (Integer.class.equals(type)) {
			return Integer.valueOf(value);
		} else if (Long.class.equals(type)) {
			return Long.valueOf(value);
		} else if (Double.class.equals(type)) {
			return Double.valueOf(value);
		} else if (Float.class.equals(type)) {
			return Float.valueOf(value);
		}
		return value;
	}

	/**
	 * Creates the initial value handed to the data binder: a zero/empty value for simple types, or
	 * a new instance via {@link BeanUtils#instantiateClass(Class)} for any other type.
	 */
	private Object createAttribute(Class<?> parameterType) {
		if (parameterType.equals(String.class)) {
			return "";
		} else if (parameterType.equals(Integer.class)) {
			return 0;
		} else if (parameterType.equals(Long.class)) {
			return 0L;
		} else if (parameterType.equals(Double.class)) {
			return 0d;
		} else if (parameterType.equals(Float.class)) {
			return 0f;
		} else if (parameterType.equals(List.class)) {
			return Lists.newArrayList();
		} else if (parameterType.equals(Set.class)) {
			return Sets.newHashSet();
		} else if (parameterType.isArray()) {
			return new Object[0];
		} else {
			return BeanUtils.instantiateClass(parameterType);
		}
	}

	/**
	 * Copies the request values relevant to {@code parameter} into {@code newRequest}. For
	 * non-simple types it then binds them onto the binder's target.
	 *
	 * <ul>
	 * <li>Simple types: plain key/value parameters under the alias win; otherwise every JSON-valued
	 * request parameter is flattened into {@code newRequest}.</li>
	 * <li>Other types with {@link FromJson}: only JSON-valued request parameters are used.</li>
	 * <li>Other types without {@link FromJson}: all plain key/value parameters are copied.</li>
	 * </ul>
	 */
	private void bindParameters(NativeWebRequest request, WebDataBinder binder, MockHttpServletRequest newRequest, MethodParameter parameter, boolean isFlag) {
		ServletRequest servletRequest = request.getNativeRequest(ServletRequest.class);

		if (isFlag) {
			String paramName = getAlias(parameter);
			String[] paramVals = request.getParameterValues(paramName);
			if (paramVals != null && paramVals.length > 0) {
				// Prefer plain key/value parameters.
				newRequest.setParameter(paramName, paramVals);
			} else {
				// No plain parameter: fall back to JSON-valued parameters.
				bindJsonParameters(request, servletRequest, newRequest);
			}
			return;
		}

		if (parameter.hasParameterAnnotation(FromJson.class)) {
			bindJsonParameters(request, servletRequest, newRequest);
		} else {
			Enumeration<?> names = servletRequest.getParameterNames();
			while (names.hasMoreElements()) {
				String paramName = (String) names.nextElement();
				newRequest.setParameter(paramName, request.getParameterValues(paramName));
			}
		}

		((ExtendedServletRequestDataBinder) binder).bind(newRequest);
	}

	/**
	 * Flattens into {@code newRequest} every request parameter whose single value starts with
	 * <code>{"</code> and ends with <code>}</code>, treating that value as a JSON object.
	 */
	private void bindJsonParameters(NativeWebRequest request, ServletRequest servletRequest, MockHttpServletRequest newRequest) {
		Enumeration<?> names = servletRequest.getParameterNames();
		while (names.hasMoreElements()) {
			String paramName = (String) names.nextElement();
			String[] paramVals = request.getParameterValues(paramName);
			if (paramVals != null && paramVals.length == 1
					&& paramVals[0].startsWith("{\"") && paramVals[0].endsWith("}")) {
				JsonObject jsonObject = new JsonParser().parse(paramVals[0]).getAsJsonObject();
				bindJsonToRequest(newRequest, jsonObject, StringUtils.EMPTY);
			}
		}
	}

	/**
	 * Recursively writes a JSON element into {@code newRequest} using Spring property-path keys.
	 * Object members become {@code key.member} and array elements become {@code key[i]}. Primitives
	 * are stored as strings; a primitive whose key already has values is appended to them. JSON
	 * nulls are skipped.
	 */
	private void bindJsonToRequest(MockHttpServletRequest newRequest, JsonElement jsonElement, String key) {
		if (jsonElement.isJsonPrimitive()) {
			String[] paramVals = newRequest.getParameterValues(key);
			if (paramVals != null && paramVals.length > 0) {
				// Append after the existing values (the last slot is at index paramVals.length).
				String[] newParamVals = new String[paramVals.length + 1];
				System.arraycopy(paramVals, 0, newParamVals, 0, paramVals.length);
				newParamVals[paramVals.length] = jsonElement.getAsString();
				newRequest.setParameter(key, newParamVals);
			} else {
				newRequest.setParameter(key, jsonElement.getAsString());
			}
		} else if (jsonElement.isJsonObject()) {
			String prefix = StringUtils.isNotBlank(key) ? key + "." : key;
			for (Map.Entry<String, JsonElement> entry : jsonElement.getAsJsonObject().entrySet()) {
				bindJsonToRequest(newRequest, entry.getValue(), prefix + entry.getKey());
			}
		} else if (jsonElement.isJsonArray()) {
			JsonArray jsonArray = jsonElement.getAsJsonArray();
			for (int i = 0; i < jsonArray.size(); i++) {
				bindJsonToRequest(newRequest, jsonArray.get(i), key + "[" + i + "]");
			}
		}
	}

	/**
	 * Returns the request key used for a parameter. A non-simple parameter annotated with
	 * {@link FromJson} uses the annotation's non-blank {@code key()}. Every other case uses the
	 * method parameter's own name.
	 *
	 * @param parameter the controller method parameter
	 * @return the key to read values from
	 */
	private String getAlias(MethodParameter parameter) {
		String alias = null;
		if (parameter.hasParameterAnnotation(FromJson.class) && !isSimpleType(parameter.getParameterType())) {
			alias = parameter.getParameterAnnotation(FromJson.class).key();
		}

		if (StringUtils.isBlank(alias)) {
			alias = parameter.getParameterName();
		}
		return alias;
	}
}